Yet, S4 relies on an intricate algorithm and a complicated
implementation for its diagonal plus low-rank (DPLR) state
space models (SSMs). To motivate this representation, S4 considered the
case of diagonal state matrices, and outlined an
extremely simple algorithm that can be implemented in just a few lines.
However, this was not used because no diagonal SSMs were known that
could mathematically model long-range dependencies - S4’s ultimate goal.
Instead, S4 used a class of special matrices that could not be
diagonalized, but found a way to transform them into almost
diagonal form, leading to the more general DPLR representation.
However, at the end of March 2022, an effective diagonal model (DSS) was
discovered in [Diagonal State
Spaces are as Effective as Structured State Spaces ], based on
approximating S4’s matrix. This important observation allows
diagonal SSMs to be used while preserving the empirical strengths of S4!
Diagonal SSMs were further fleshed out in [On the
Parameterization and Initialization of Diagonal State Space Models ],
which implemented S4’s original diagonal algorithm combined with new
theory explaining why this particular diagonal initialization can model
long-range dependencies (S4D). The rest of this post steps through this
much simpler model, an even more structured state space built on
diagonal matrices.
Part I. A Refresher on State Space Models
Part II. Diagonal State Space Models
The Diagonal SSM Algorithm - Vandermonde Matrix Multiplication
Implementing the S4D Kernel
Comparing SSM Parameterizations
The Complete S4D Layer
Part IIIa. The Central Challenge: Initialization
A Brief Refresher on S4 and HiPPO
The Diagonal HiPPO Matrix
Part IIIb. An Intuitive Understanding of SSMs
Case: 1-dimensional State
Case: Diagonal SSM
Case: General SSM
Case: HiPPO and Diagonal HiPPO
Other Diagonal Initializations
Part I of this post provides a quick summary of SSMs to define their
main computational challenge. In Part II, we step through the complete
derivation and implementation of S4D, following the original S4 paper.
Notably, the core kernel computation requires only 2 lines of
code! Finally, Part III covers the theory of diagonal SSMs,
from how S4 originally modeled long-range dependencies, to the new
breakthroughs in initializing DSS and S4D.
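Throughout this post, the code examples assume a JAX/Flax setup along the lines of the Annotated S4. Below is a minimal sketch of the imports we assume; the s4 module refers to the companion code from that post, and the specific PRNG seed is an assumption used only by the test code later on.

import jax
import jax.numpy as np
from functools import partial
from flax import linen as nn
from jax.nn.initializers import normal
import s4  # companion module from the Annotated S4 (assumed to be on the path)

rng = jax.random.PRNGKey(1)  # PRNG key used by the test code below (seed is arbitrary)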
Part I. A Refresher
on State Space Models
We’re going to start by taking a step back – back to the original
State Space Model (SSM) itself. The original SSM is defined over
continuous time inputs, as follows (from the original S4
paper)
[TODO: Link to original post]
The state
space model is defined by this simple equation. It maps a 1-D input
signal u(t) to an N -D latent state x(t) before projecting to a 1-D output signal
y(t) .
\begin{aligned}
x'(t) &= \boldsymbol{A}x(t) + \boldsymbol{B}u(t) \\
y(t) &= \boldsymbol{C}x(t)
\end{aligned}
Our goal is to simply use the SSM as a black-box representation
in a deep sequence model, where \boldsymbol{A}, \boldsymbol{B},
\boldsymbol{C} are parameters learned by gradient descent…
An SSM maps an input u(t) to a state
representation vector x(t) and an
output y(t) . For simplicity, we assume
the input and output are one-dimensional, and the state representation
is N -dimensional. The first equation
defines the change in x(t) over
time.
[AG: In the DSS post, Sidd’s elaboration on discretization is great
and should be in Part 1 of the Annotated S4, as they are general facts
about SSMs independent of S4/DSS. I also recommend looking at my blog post
on discretization ]
Recall also that in discrete time, the SSM is viewed as a
sequence-to-sequence map (u_k) \mapsto
(y_k) , where the sequence u_k = u(k
\Delta) represents sampling the underlying continuous u(t) with a fixed sampling interval or step
size \Delta .
Part 1 of the S4 post showed that this discretized state-space model
can be viewed as a linear RNN with a transition matrix given by \boldsymbol{\overline{A}} :
\begin{aligned}
x_{k} &= \boldsymbol{\overline{A}} x_{k-1} +
\boldsymbol{\overline{B}} u_k\\
y_k &= \boldsymbol{\overline{C}} x_k \\
\end{aligned}
Note that when \boldsymbol{A} is
diagonal, the first equation decomposes as independent 1-dimensional
recurrences over the elements of x
(Splash figure, Left )!
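As a concrete reminder, here is a minimal sketch of this recurrence as a linear RNN using jax.lax.scan. It mirrors the s4.scan_SSM helper used later in this post; the name and exact signature here are illustrative.

def scan_SSM_sketch(Ab, Bb, Cb, u, x0):
    """Run x_k = Ab x_{k-1} + Bb u_k and y_k = Cb x_k over a sequence u, starting from x0."""
    def step(x_k_1, u_k):
        x_k = Ab @ x_k_1 + Bb @ u_k
        return x_k, Cb @ x_k
    # scan returns the final state and the stacked outputs y_1, ..., y_L
    return jax.lax.scan(step, x0, u)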
We then showed how we can turn the above recurrence into a
convolution because of the repetitive structure (more formally
because the recurrence is time-invariant ). We end up with the
kernel:
\begin{aligned}
y_k &= \boldsymbol{\overline{C}} \boldsymbol{\overline{A}}^k
\boldsymbol{\overline{B}} u_0 + \boldsymbol{\overline{C}}
\boldsymbol{\overline{A}}^{k-1} \boldsymbol{\overline{B}} u_1 + \dots +
\boldsymbol{\overline{C}} \boldsymbol{\overline{A}}
\boldsymbol{\overline{B}} u_{k-1} +
\boldsymbol{\overline{C}}\boldsymbol{\overline{B}} u_k
\\
y &= \boldsymbol{\overline{K}} \ast u
\end{aligned}
Recall that N denotes the state
size, or dimensionality of \boldsymbol{A} , while L denotes the sequence length.
\begin{aligned}
\boldsymbol{\overline{K}} \in \mathbb{R}^L =
(\boldsymbol{\overline{C}}\boldsymbol{\overline{B}},
\boldsymbol{\overline{C}}\boldsymbol{\overline{A}}\boldsymbol{\overline{B}},
\dots,
\boldsymbol{\overline{C}}\boldsymbol{\overline{A}}^{L-1}\boldsymbol{\overline{B}})
\end{aligned}
Problem : SSMs in deep learning have two core
challenges. The modeling challenge is finding good parameters
of the SSM, in particular the state matrix \boldsymbol{A} , that can effectively model
complex interactions in sequential data such as long-range dependencies.
We defer this discussion, which is more theoretical, to Part III.
The core computational challenge of SSMs is constructing
this kernel \boldsymbol{\overline{K}}
fast. Overcoming this requires imposing structure on
the state space. Next, let’s see how diagonal SSMs provide a simple way
to do this.
Part II. Diagonal State
Space Models
Let’s now examine more closely how to compute this discretized SSM.
This part will directly follow Section 3.1 of the original S4 paper.
The fundamental bottleneck in computing the discrete-time SSM is that
it involves repeated matrix multiplication by \boldsymbol{\overline{A}} . For example,
computing \boldsymbol{\overline{K}}
naively involves L successive
multiplications by \boldsymbol{\overline{A}} , requiring O(N^2 L) operations and O(NL) space.
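To make this cost concrete, a naive construction of the kernel looks like the following sketch (the helper name is illustrative; the Annotated S4 defines a similar K_conv helper):

def K_conv_naive(Ab, Bb, Cb, L):
    """Materialize K_l = Cb @ Ab^l @ Bb by L successive multiplications with Ab."""
    K, x = [], Bb
    for _ in range(L):
        K.append((Cb @ x).reshape(()))  # scalar kernel element
        x = Ab @ x                      # O(N^2) per step, O(N^2 L) total
    return np.stack(K)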
In other words, computing this kernel \boldsymbol{\overline{K}} is prohibitively
expensive for general state matrices \boldsymbol{A} . Getting SSMs to scale
requires finding an alternative way to compute this kernel – one that
is efficient and that doesn’t overly restrict the expressivity of
\boldsymbol{A} . So how can we address
this?
To overcome this bottleneck, we use a structural result that allows
us to simplify SSMs.
Lemma 1. Conjugation is an equivalence relation on
SSMs (\boldsymbol{A}, \boldsymbol{B},
\boldsymbol{C}) \sim (\boldsymbol{V}^{-1} \boldsymbol{A} \boldsymbol{V},
\boldsymbol{V}^{-1}\boldsymbol{B},
\boldsymbol{C}\boldsymbol{V}) .
Proof. Write out the two SSMs with state denoted by
x and \tilde{x} respectively:
\begin{aligned}
x' &= \boldsymbol{A}x + \boldsymbol{B}u & \qquad \qquad
\qquad \tilde{x}' &=
\boldsymbol{V}^{-1}\boldsymbol{A}\boldsymbol{V}\tilde{x} +
\boldsymbol{V}^{-1}\boldsymbol{B}u \\
y &= \boldsymbol{C}x & \qquad \qquad \qquad y &=
\boldsymbol{C}\boldsymbol{V}\tilde{x}
\end{aligned}
After multiplying the right side SSM by \boldsymbol{V} , the two SSMs become identical
with x = \boldsymbol{V}\tilde{x} .
Therefore these compute the exact same operator u \mapsto y , but with a change of basis by
\boldsymbol{V} in the state x .
Why is this important? It allows replacing \boldsymbol{A} with a canonical
form such as diagonal matrices, simplifying the structure while
preserving expressivity! Ideally, this structure would simplify and
speed up the computation of the SSM kernel.
Note that this provides an immediate proof of the expressivity of
diagonal SSMs. [footnote: a more complicated version is presented as
Proposition 1 of the DSS paper] To spell it out: suppose we have a state
space with parameters (\boldsymbol{A},
\boldsymbol{B}, \boldsymbol{C}) where the matrix \boldsymbol{A} is diagonalizable - in other
words, there exists a matrix \boldsymbol{V} such that \boldsymbol{V}^{-1}\boldsymbol{A}\boldsymbol{V}
is diagonal. Then the state space (\boldsymbol{V}^{-1} \boldsymbol{A} \boldsymbol{V},
\boldsymbol{V}^{-1}\boldsymbol{B}, \boldsymbol{C}\boldsymbol{V})
is a diagonal SSM that is exactly equivalent , or in other words
defines the exact same sequence-to-sequence transformation u \mapsto y !
Furthermore, it’s well known that almost
all matrices are diagonalizable , so that diagonal SSMs are
essentially fully expressive (with a caveat that we’ll talk about in
Part III).
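Here is a quick numerical check of this equivalence, using the K_conv_naive sketch from above. This is illustrative code under our assumed setup; note that np.linalg.eig in JAX currently runs on CPU only.

N, L = 4, 16
keys = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(keys[0], (N, N)) / N   # scale down so powers of A stay well-behaved
B = jax.random.normal(keys[1], (N, 1))
C = jax.random.normal(keys[2], (1, N))
# Diagonalize A = V diag(Lam) V^{-1} and conjugate the SSM by V
Lam, V = np.linalg.eig(A)
A_diag, B_diag, C_diag = np.diag(Lam), np.linalg.solve(V, B), C @ V
# The two SSMs define the same kernel (and hence the same map u -> y)
assert np.allclose(K_conv_naive(A, B, C, L),
                   K_conv_naive(A_diag, B_diag, C_diag, L).real,
                   atol=1e-4, rtol=1e-4)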
The
Diagonal SSM Algorithm - Vandermonde Matrix Multiplication
So what’s the computational advantage of diagonal SSMs? S4 outlined
the main idea:
Lemma 1 motivates putting \bm{A}
into a canonical form by conjugation, which is ideally more structured
and allows faster computation. For example, if \bm{A} were diagonal, the resulting
computations become much more tractable. In particular, the desired
\bm{\overline{K}} would be a
Vandermonde product which theoretically only needs
O((N+L)\log^2(N+L)) arithmetic
operations.
Let’s elaborate on this in more detail. The key idea is that when
\boldsymbol{\overline{A}} is diagonal,
the matrix power can be broken into a collection of scalar
powers, dramatically simplifying the structure of the kernel \boldsymbol{\overline{K}} . In particular, the
\ell -th element of the convolution
kernel is
\boldsymbol{\overline{K}}_\ell =
\boldsymbol{C}\boldsymbol{\overline{A}}^\ell\boldsymbol{\overline{B}} =
\sum_{n=0}^{N-1} \boldsymbol{C}_n \boldsymbol{\overline{A}}_n^\ell
\boldsymbol{\overline{B}}_n
But this can be rewritten as a single matrix-vector product:
\begin{aligned}
\boldsymbol{\overline{K}} =
\begin{bmatrix}
\boldsymbol{\overline{B}}_0 \boldsymbol{C}_0 & \dots &
\boldsymbol{\overline{B}}_{N-1} \boldsymbol{C}_{N-1}
\end{bmatrix}
\begin{bmatrix}
1 & \boldsymbol{\overline{A}}_0 &
\boldsymbol{\overline{A}}_0^2 & \dots &
\boldsymbol{\overline{A}}_0^{L-1} \\
1 & \boldsymbol{\overline{A}}_1 &
\boldsymbol{\overline{A}}_1^2 & \dots &
\boldsymbol{\overline{A}}_1^{L-1} \\
\vdots & \vdots &
\vdots & \ddots &
\vdots \\
1 & \boldsymbol{\overline{A}}_{N-1} &
\boldsymbol{\overline{A}}_{N-1}^2 & \dots &
\boldsymbol{\overline{A}}_{N-1}^{L-1} \\
\end{bmatrix}
\end{aligned}
The matrix on the right side is known as a Vandermonde
matrix , where the columns encode successive powers of \boldsymbol{\overline{A}} .
More importantly, writing the kernel in this form immediately exposes
the computational complexity! Naively, materializing the matrix requires
O(NL) space and the multiplication
takes O(NL) time. But Vandermonde
matrices are very well-studied, and it’s known that they can be
multiplied in \widetilde{O}(N+L)
operations and O(N+L) space, providing
an asymptotic efficiency improvement.
We make note of a small implementation detail: the SSM depends only
on the elementwise product \boldsymbol{C}
\circ \boldsymbol{B} . So we can assume without loss of generality
that \boldsymbol{B} = \boldsymbol{1}
and choose to either train it (as in S4(D)) or freeze it (as in DSS).
(footnote: DSS also renames \boldsymbol{C} to W in their presentation, but we find the
original notation clearer.)
Implementing the S4D Kernel
Implementing this simple version of S4 for the diagonal case is very
straightforward. As with all SSMs, the first step is to discretize the
parameters with a step size \Delta .
[Link to Post 1 ] This is much simpler for diagonal
state matrices \boldsymbol{A} , as the
discretization normally involves matrix inverses or exponentials that
can be broken up into scalar operations.
def discretize(A, B, step, mode="zoh"):
if mode == "bilinear":
return (1+step/2*A) / (1-step/2*A), step*B / (1-step/2*A)
elif mode == "zoh":
return np.exp(step*A), (np.exp(step*A)-1)/A * B
Here we show both the Bilinear method used in S4 and HiPPO, and the
ZOH method used in other SSMs such as DSS and LMU .
(As discussed in Part 1 of the Annotated S4 [AG: if we put more about
discretization there], these are closely related and have no real
empirical difference.)
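As a quick sanity check (with illustrative values), the two discretizations agree closely for small step sizes:

A, B, step = np.array([-0.5 + 1j * np.pi]), np.array([1.0]), 0.01
print(discretize(A, B, step, mode="bilinear"))
print(discretize(A, B, step, mode="zoh"))
# Both return Abar close to exp(step * A); the two methods agree up to O(step^3) terms.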
The Vandermonde matrix multiplication is almost trivial to implement
and can be applied to any discretization of a diagonal SSM.
def vandermonde(v, L, alpha):
"""
Computes v @ Vandermonde(alpha, L)
v, alpha: shape (N,)
Returns: shape (L,)
"""
V = alpha[:, np.newaxis] ** np.arange(L) # Vandermonde matrix
return (v[np.newaxis, :] @ V)[0]
def s4d_kernel(C, A, L, step):
Abar, Bbar = discretize(A, 1.0, step)
return vandermonde(C*Bbar, L, Abar).real
Finally, this Vandermonde matrix multiply can be slightly optimized.
First, computing powers \alpha^k
explicitly can be slower than computing them as \exp(k \log(\alpha)) . Second, in the case of
ZOH discretization, which directly involves a matrix exponential, a
\log \circ \exp term can be removed,
saving a pointwise operation. Finally, materializing the full
Vandermonde matrix is unnecessary and can be optimized away to save a
lot of memory! We elaborate on this below. [Link ]
@partial(jax.jit, static_argnums=2)
def s4d_kernel_zoh(C, A, L, step):
""" A version of the kernel for B=1 and ZOH """
kernel_l = lambda l: (C * (np.exp(step*A)-1)/A * np.exp(l*step*A)).sum()
return jax.vmap(kernel_l)(np.arange(L)).ravel().real
As the original S4 paper specified, this S4D kernel is just a single
Vandermonde matrix-vector product. Just as with all SSMs, we can test
that convolving by this kernel produces the same answer as the
sequential scan.
def s4d_ssm(C, A, L, step):
N = A.shape[0]
Abar, Bbar = discretize(A, np.ones(N), step, mode="zoh")
Abar = np.diag(Abar)
Bbar = Bbar.reshape(N, 1)
Cbar = C.reshape(1, N)
return Abar, Bbar, Cbar
def test_conversion(N=8, L=16):
"""Test the equivalence of the S4D kernel with the generic SSM kernel."""
step = 1.0 / L
C = normal()(rng, (N, 2))
C = C[..., 0] + 1j * C[..., 1]
A, _, _, _ = s4.make_DPLR_HiPPO(N)
A = A[np.nonzero(A.imag > 0, size=N)]
K_ = s4d_kernel(C, A, L, step)
K = s4d_kernel_zoh(C, A, L, step)
assert np.allclose(K_, K, atol=1e-4, rtol=1e-4)
ssm = s4d_ssm(C, A, L, step)
# # Apply CNN
u = np.arange(L) * 1.0
y1 = s4.causal_convolution(u, K)
# # Apply RNN
_, y2 = s4.scan_SSM(
*ssm, u[:, np.newaxis], np.zeros((N,)).astype(np.complex64)
)
assert np.allclose(y1, y2.reshape(-1).real, atol=1e-4, rtol=1e-4)
Comparing SSM
Parameterizations
With all these different SSM methods floating around, let’s quickly
compare some versions of SSMs to understand their similarities and
differences. The different parameterizations are the full S4 (for DPLR
matrices), S4D (the diagonal case of S4, presented above), and DSS (an
alternate version of diagonal matrices).
S4. First, let’s revisit once more the main point of
S4’s algorithm, which dramatically improved the efficiency of computing
the SSM kernel for DPLR matrices.
For state dimension N and sequence
length L , computing the latent state
requires O(N^2 L) operations and O(NL) space - compared to a \Omega(L+N) lower bound for both. […] S4
reparameterizes the structured state matrices \bm{A} from HiPPO by decomposing them as the
sum of a low-rank and normal term […] ultimately reducing to a
well-studied and theoretically stable Cauchy kernel. This results in
\widetilde{O}(N+L) computation and
O(N+L) memory usage, which is
essentially tight for sequence models.
In other words, the whole point of S4’s complicated algorithm was to reduce the
DPLR SSM kernel to a Cauchy matrix
multiplication, which is well-studied and fast.
S4D. Notice that the S4D algorithm is very similar,
ultimately reducing to a Vandermonde matrix multiplication which has the
same asymptotic efficiency. In fact, this is no surprise - Vandermonde
matrices and Cauchy matrices are very closely related, and have
essentially identical computational complexities because they can be
easily transformed to one another. [AG: how to add citation?] It’s neat
that generalizing the diagonal case to diagonal plus low-rank simply
reduces to a slightly different, but computationally equivalent, linear
algebra primitive!
We note that in practice, the near-linear \widetilde{O}(N+L) time algorithm for
Vandermonde and Cauchy matrices may be less efficient than naively doing
the O(NL) summation due to hardware
efficiency. However, exposing the structure of Vandermonde and Cauchy
matrices allows the kernels to be written in a way that avoids
materializing the full matrix (as our code above does, leveraging JAX’s
clever jit compilation), reducing the space complexity from O(NL) to O(N+L) .
DSS. Finally, DSS presented a slightly different
version of the S4D algorithm, which was specialized to the ZOH
discretization and introduced a softmax that normalizes
over the sequence length. This was introduced to potentially stabilize
the case when \boldsymbol{A} can have
positive eigenvalues, but has some disadvantages including being
somewhat more complicated and less efficient, and calibrated only to a
particular sequence length. A more in depth comparison is discussed in
the S4D paper.
The Complete S4D Layer
With the core convolutional kernel \boldsymbol{\overline{K}} in place, we’re
ready to put the S4D layer together!
class S4DLayer(nn.Module):
N: int
l_max: int
decode: bool = False
lr = {
"A_re": 0.1,
"A_im": 0.1,
"log_step": 0.1,
}
def setup(self):
# Learned Parameters
hippo_A_real_initializer, hippo_A_imag_initializer, _, _ = s4.hippo_initializer(self.N)
self.A_re = self.param("A_re", hippo_A_real_initializer, (self.N,))
self.A_im = self.param("A_im", hippo_A_imag_initializer, (self.N,))
self.A = np.clip(self.A_re, None, -1e-4) + 1j*self.A_im
self.C = self.param("C", normal(), (self.N, 2))
self.C = self.C[..., 0] + 1j * self.C[..., 1]
self.D = self.param("D", nn.initializers.ones, (1,))
self.step = np.exp(
self.param("log_step", s4.log_step_initializer(), (1,))
)
if not self.decode:
self.K = s4d_kernel_zoh(self.C, self.A, self.l_max, self.step)
else:
# FLAX code to ensure that we only compute discrete once during decoding.
def init_discrete():
return s4d_ssm(self.C, self.A, self.l_max, self.step)
ssm_var = self.variable("prime", "ssm", init_discrete)
if self.is_mutable_collection("prime"):
ssm_var.value = init_discrete()
self.ssm = ssm_var.value
# RNN Cache
self.x_k_1 = self.variable(
"cache", "cache_x_k", np.zeros, (self.N,), np.complex64
)
def __call__(self, u):
if not self.decode:
return s4.causal_convolution(u, self.K) + self.D * u
else:
x_k, y_s = s4.scan_SSM(*self.ssm, u[:, np.newaxis], self.x_k_1.value)
if self.is_mutable_collection("cache"):
self.x_k_1.value = x_k
return y_s.reshape(-1).real + self.D * u
S4DLayer = s4.cloneLayer(S4DLayer)
The core of the S4D layer is the same as the traditional SSM layer
defined in the first part of the post. We define our learnable weights
(such as C ), then call the kernel code written
above as a convolution during training. Finally, during discrete
(autoregressive) decoding, we instead step the discretized recurrence computed above.
… and that’s all folks! S4D is dramatically easier to understand
and more compact (50 LoC!) than S4, with an extremely structured state space
that reduces to a single Vandermonde product. Together with the new
theoretical insights in the next section, we can build a model that is
almost as expressive and performant as S4.
Part IIIa. The
Central Challenge: Initialization
The final piece in the above code left unexplained so far is also the
most important: how should we initialize the SSM parameters, in
particular the diagonal matrix \boldsymbol{A} ?
In order to understand this key breakthrough that made diagonal SSMs
perform well, we have to briefly revisit the motivation and theoretical
interpretation of S4. This is the only part of this post that requires
some mathematical background, but is optional: the entire algorithm is
already fully contained in Parts I and II.
The initialization is given by the
hippo_A_real_initializer and hippo_A_imag_initializer lines, which set
\boldsymbol{A} to the diagonal part of the DPLR
representation of S4’s HiPPO matrix. For the rest of this post, we give
some historical context and intuition for this initialization.
A Brief Refresher on S4 and
HiPPO
Recall that the critical question for state space models is how to
parameterize and initialize the state matrix \boldsymbol{A} in a way that can (i) be
computed efficiently and (ii) model complex interactions in the data
such as long range dependencies.
Although the diagonal SSM algorithm presented above is very simple
and efficient, it’s actually extremely difficult to find a diagonal
\boldsymbol{A} that performs well!
As a refresher, S4’s motivation was to instead use a particular
formula for the \boldsymbol{A} matrix
called a HiPPO matrix
that has a mathematical interpretation of memorizing the history of the
input u(t) . This theory is what gives
S4 its remarkable performance on long sequence modeling, described in
this figure from [How to Train Your
HiPPO ].
An
illustration of HiPPO for L=10000,
N=64 .
Here, an input signal u(t)
(Black ) is processed by the HiPPO operator x' = \boldsymbol{A}x + \boldsymbol{B}u
(Blue ) for 10000 steps,
maintaining a state x(t) \in
\mathbb{R}^{64} . At all times, the current state represents a
compression of the history of u(t) and
can be linearly projected to approximately reconstruct it
(Red ). This approximation is optimal with respect to an
exponentially-decaying measure (Green ).
The primary challenge that S4 addressed is how to efficiently compute
with this matrix \boldsymbol{A} . The
HiPPO matrix has a simple closed-form formula:
\boldsymbol{A} =
-
\begin{bmatrix}
1 & 0 & 0 & 0 \\
(3 \cdot 1)^{\frac{1}{2}} & 2 & 0 & 0 \\
(5 \cdot 1)^{\frac{1}{2}} & (5 \cdot 3)^{\frac{1}{2}} & 3 &
0 \\
(7 \cdot 1)^{\frac{1}{2}} & (7 \cdot 3)^{\frac{1}{2}} & (7 \cdot
5)^{\frac{1}{2}} & 4 \\
\end{bmatrix}
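For reference, here is a sketch of how this matrix can be constructed for general N; it simply transcribes the pattern above (the Annotated S4 provides an equivalent make_HiPPO helper, and the name here is illustrative).

def make_HiPPO_sketch(N):
    """HiPPO-LegS matrix: entry -sqrt((2n+1)(2k+1)) below the diagonal, -(n+1) on it, 0 above."""
    def entry(n, k):
        if n > k:
            return ((2 * n + 1) * (2 * k + 1)) ** 0.5
        return (n + 1) if n == k else 0.0
    return -np.array([[entry(n, k) for k in range(N)] for n in range(N)])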
Note that this matrix is not diagonal, but it is diagonalizable (with
eigenvalues -1, -2, -3, \dots ) - so we
can hope to apply Lemma 1. Alas, S4 showed that this doesn’t work
because \boldsymbol{V} has
exponentially large entries.
Unfortunately, the naive application of diagonalization does not work
due to numerical issues. […] The ideal scenario is when the matrix \bm{A} is diagonalizable by a perfectly
conditioned (i.e., unitary) matrix. By the Spectral Theorem of linear
algebra, this is exactly the class of normal matrices .
However, this class of matrices is restrictive; in particular, it does
not contain the HiPPO matrix.
This discussion highlights the key limitation of diagonal
SSMs : although expressive in theory (algebraically ),
they are not necessarily expressive in practice (numerically ).
To circumvent this, S4 discovered a new way to put a matrix in
almost diagonal form, while only needing to conjugate by
unitary matrices which are perfectly stable.
\begin{aligned}
\boldsymbol{A} &= \boldsymbol{A^{(N)}} -
\boldsymbol{P}\boldsymbol{P}^\top \\
&= -
\begin{bmatrix}
\frac{1}{2} & -\frac{1}{2}(3 \cdot 1)^{\frac{1}{2}} &
-\frac{1}{2}(5 \cdot 1)^{\frac{1}{2}} & -\frac{1}{2}(7 \cdot
1)^{\frac{1}{2}} \\
\frac{1}{2}(3 \cdot 1)^{\frac{1}{2}} & \frac{1}{2} &
-\frac{1}{2}(5 \cdot 3)^{\frac{1}{2}} & -\frac{1}{2}(7 \cdot
3)^{\frac{1}{2}} \\
\frac{1}{2}(5 \cdot 1)^{\frac{1}{2}} & \frac{1}{2}(5 \cdot
3)^{\frac{1}{2}} & \frac{1}{2} & -\frac{1}{2}(7 \cdot
5)^{\frac{1}{2}} \\
\frac{1}{2}(7 \cdot 1)^{\frac{1}{2}} & \frac{1}{2}(7 \cdot
3)^{\frac{1}{2}} & \frac{1}{2}(7 \cdot 5)^{\frac{1}{2}} &
\frac{1}{2} \\
\end{bmatrix}
- \frac{1}{2}
\begin{bmatrix}
1^{\frac{1}{2}} \\
3^{\frac{1}{2}} \\
5^{\frac{1}{2}} \\
7^{\frac{1}{2}} \\
\end{bmatrix}
\begin{bmatrix}
1^{\frac{1}{2}} \\
3^{\frac{1}{2}} \\
5^{\frac{1}{2}} \\
7^{\frac{1}{2}} \\
\end{bmatrix}^\top
\end{aligned}
As discussed in Part 1 of the Annotated S4 [Link], the first
component \boldsymbol{A}^{(N)} is a
normal matrix which is unitarily diagonalizable, hence \boldsymbol{A} is unitarily equivalent to a
DPLR matrix. This led to all of the fancy machinery to compute the DPLR
kernel that S4 introduced.
The Diagonal HiPPO Matrix
Finally, we can describe the key fact that made diagonal SSMs work.
DSS’s core contribution is showing that simply masking out the
low-rank portion of the HiPPO matrix results in a diagonal
matrix that empirically performs almost as well as S4. This is the key
“fork in the road” between the original S4 paper, and the follow-up
diagonal SSMs which all use this diagonal approximation of the HiPPO
matrix.
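To make this concrete, here is a sketch of how the diagonal-HiPPO initialization can be computed from the decomposition above, using the make_HiPPO_sketch from earlier. Since the real parts of the eigenvalues of \boldsymbol{A}^{(N)} are all -\frac{1}{2}, only a stable Hermitian eigendecomposition is needed for the imaginary parts (the function name is illustrative).

def make_diag_HiPPO_sketch(N):
    """Eigenvalues of the normal part A^(N) = A + P P^T of the HiPPO matrix."""
    A = make_HiPPO_sketch(N)
    P = np.sqrt(np.arange(N) + 0.5)               # low-rank factor from the formula above
    S = A + P[:, np.newaxis] * P[np.newaxis, :]   # the normal matrix A^(N)
    K = S - np.diag(np.diagonal(S))               # S = -1/2 I + K with K real skew-symmetric
    Lambda_imag = np.linalg.eigvalsh(K * -1j)     # -1j*K is Hermitian, so eigvalsh is stable
    return -0.5 + 1j * Lambda_imag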
It can be hard to appreciate how surprising and subtle this fact is.
It’s important to note that writing the HiPPO matrix in DPLR form was
S4’s main contribution, but this form was purely for
computational purposes . In other words, the diagonal
and low-rank portions by themselves should have no mathematical
meaning . In fact, other follow-ups that generalize and explain S4 introduce different
variants of S4 that all have a DPLR representation, but where dropping
the low-rank term to convert it to a diagonal matrix performs much
worse.
It turns out that this particular matrix is extremely special, and
the diagonal HiPPO matrix does have a theoretical
interpretation. Dropping the low-rank term - leaving only the normal
term \boldsymbol{A}^{(N)} - has the
same dynamics as \boldsymbol{A} in the
limit as the state size N \to
\infty . This is a pretty remarkable fact proved in the S4D
paper, and honestly still seems like an incredible coincidence. In the
rest of this post, we’ll unpack this fact and try to give more intuition
for SSMs.
Part IIIb: An
Intuitive Understanding of SSMs
We’ll close out this blog post with some discussion on how to think
about SSMs, illustrated through diagonal SSMs. We’ll focus on intuition
for the following question:
Q: How should we interpret the convolution kernel of a state
space model ?
Case: 1-dimensional state
Let’s start with the case of an SSM with N=1 . We’ll write lowercase \bm{a} and \bm{b} to emphasize that they’re scalars. The
state x(t) is then a scalar function
that satisfies a linear ODE, which is elementary to solve. The original
SSM state equation
\begin{aligned}
\frac{d}{dt} x(t) &= \bm{a} x(t) + \bm{b} u(t) \\
\end{aligned}
can be multiplied by a simple term (called an integrating
factor ) to produce a simpler ODE,
\begin{aligned}
\frac{d}{dt} e^{-t\bm{a}} x(t) &= -\bm{a} e^{-t\bm{a}} x(t) +
e^{-t\bm{a}}x'(t)
\\&= e^{-t\bm{a}} \bm{b} u(t)
.
\end{aligned}
This can be explicitly integrated
\begin{aligned}
e^{-t\bm{a}} x(t) &= \int_0^t e^{-s\bm{a}} \bm{b} u(s) \; ds
\\
\end{aligned}
which yields a closed formula for the state
\begin{aligned}
x(t) &= \int_0^t e^{(t-s)\bm{a}} \bm{b} u(s) \; ds
\\&= (e^{t\bm{a}} \bm{b}) \ast u(t)
\end{aligned}
This shows that the state x(t) is
just the (causal) convolution of the input u(t) with an exponentially decaying kernel
e^{t\bm{a}} \bm{b} !
Case: Diagonal SSM
When \bm{A} is diagonal, the
equation x'(t) = \bm{A} x(t) + \bm{B}
u(t) can simply be broken into N
independent scalar SSMs, as illustrated in the splash figure
(Left ). Therefore each element of the state, which we can
denote x_n(t) , convolves the input by
the exponentially decaying kernel e^{t\bm{A}_n}\bm{B}_n . These kernels are
illustrated in the splash figure (Right ), visualized again
here:
The N different scalar equations can
be vectorized to just say that
x(t) = (e^{t\bm{A}} \bm{B}) \ast u(t)
The above figure very concretely plots (the real part of) the 4
functions e^{t\bm{A}} \bm{B} for
\bm{A} =
\begin{bmatrix}
-\frac{1}{2} + i \pi & 0 & 0 & 0 \\
0 & -\frac{1}{2} + i 2\pi & 0 & 0 \\
0 & 0 & -\frac{1}{2} + i 3\pi & 0 \\
0 & 0 & 0 & -\frac{1}{2} + i 4\pi \\
\end{bmatrix}
\qquad \bm{B} =
\begin{bmatrix}
1 \\ 1 \\ 1 \\ 1 \\
\end{bmatrix}
For example, the first basis function (Blue ) is just e^{-\frac{1}{2}t} e^{i\pi t} . Notice how the
real part of \bm{A} controls
the decay of these functions (dotted lines ), while the
imaginary part controls the frequency.
What about y(t) ? This is just a
linear combination of the state, y(t) = \bm{C}
x(t) . But by linearity of convolution, we can push \bm{C} inside:
y(t) = \bm{C} (e^{t\bm{A}} \bm{B} \ast u(t)) = (\bm{C} e^{t\bm{A}}
\bm{B}) \ast u(t)
So the entire SSM is equivalent to a 1-D convolution where the
kernel is just \bm{C} e^{t\bm{A}}
\bm{B} . We interpret this kernel as a linear combination
of the N basis kernels e^{t\bm{A}} \bm{B} .
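We can connect this continuous-time view back to the discrete kernel from Part II: up to first order in the step size, the ZOH kernel is just the continuous kernel \bm{C} e^{t\bm{A}} \bm{B} sampled at multiples of \Delta and scaled by \Delta . A small illustrative check, using the example diagonal \bm{A} above and the s4d_kernel_zoh function from Part II (values and tolerance are assumptions):

N, L, step = 4, 100, 0.01
A = -0.5 + 1j * np.pi * (np.arange(N) + 1)                 # the example diagonal matrix above
C = jax.random.normal(jax.random.PRNGKey(0), (N,)) + 0j    # arbitrary complex C, with B = 1
K_discrete = s4d_kernel_zoh(C, A, L, step)
t = step * np.arange(L)
K_continuous = step * (C * np.exp(t[:, np.newaxis] * A)).sum(-1).real
assert np.allclose(K_discrete, K_continuous, atol=1e-2)    # agree up to higher-order terms in step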
Case: General SSM
The same derivation and formula for the convolution kernel actually
holds in the case of non-diagonal state matrices \bm{A} , only now it involves a matrix
exponential instead of scalar exponentials. We still interpret it
the same way: e^{t\bm{A}} \bm{B} is a
vector of N different basis
kernels , and the overall convolution kernel \bm{C} e^{t\bm{A}} \bm{B} is a linear
combination of these basis functions.
Case: HiPPO and Diagonal
HiPPO
Equipped with this understanding, we can now try to understand the
HiPPO matrix better. Although we’ve been focusing on \bm{A} , which is the more important matrix,
HiPPO actually provides exact formulas for \bm{A} and \bm{B} . So with the above interpretation,
HiPPO provides a specific set of basis functions , and the
parameter \bm{C} then learns a weighted
combination of these to use as the final convolution kernel. For the
particular (\bm{A}, \bm{B}) that S4
uses, each basis function actually has a closed-form formula as
exponentially-warped Legendre polynomials L_n(e^{-t}) . Intuitively, the state x(t) of S4 is convolving the input by each of
these very smooth, infinitely-long kernels, which gives rise to its
long-range modeling abilities.
Finally, what about the basis for S4D using this diagonal
approximation to the HiPPO matrix? Let’s plot e^{t \bm{A}^{(N)}} \frac{\bm{B}}{2} for N=256 (Left ) and N=1024 (Right ):
[AG: TODO clean up figures e.g. make x/y axes the same]
We can see that this matrix \boldsymbol{A}^{(N)} (a dense, diagonalizable
matrix) generates noisy approximations to the same kernels as \bm{A} (a triangular,
hard-to-diagonalize matrix), and the approximations become exact as N\to\infty . This is what we meant by saying
the diagonal-HiPPO matrix is a perfect approximation to the original
HiPPO matrix, which really seems like a remarkable mathematical
coincidence!
Other Diagonal
Initializations
Inspired by the HiPPO-diagonal matrix, S4D proposes several new
initializations with simpler formulas where the real part is fixed to
-\frac{1}{2} and the imaginary part
follows a polynomial law. In fact, the example in the previous section
with \boldsymbol{A}_n = -\frac{1}{2} + i \pi
n is the simplest variant, called S4D-Lin
because the imaginary part scales linearly in n .
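As a sketch, S4D-Lin amounts to a one-line initializer (the function name is illustrative):

def s4d_lin_initializer_sketch(N):
    """S4D-Lin: real part fixed to -1/2, imaginary part i*pi*n."""
    return -0.5 + 1j * np.pi * np.arange(N)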
However, for now, it seems like the full HiPPO matrix is core to S4’s
long-range modeling abilities, and the diagonal-HiPPO initialization
also seems empirically best among diagonal SSMs. The following table
shows results for several variants of S4 and S4D with various (\boldsymbol{A}, \boldsymbol{B})
initializations, using a slightly improved architecture from the
original S4 paper.
Long Range Arena (LRA) . (Top ) S4 variants with
DPLR representation; LegS is equivalent to the original S4 \bm{A} . (Middle ) S4D variants; LegS
is the diagonal-HiPPO \bm{A}^{(D)} .
(Bottom ) Previous results, including the original S4 with
different hyperparameter settings.
We see that all of them perform very well in general, and the very
simple S4D-Lin initialization is even best on several of the 5 main
tasks. However, the original S4 initialization and its diagonal
approximation are thus far the only ones that can solve Path-X.
Open challenge : are there alternative SSMs not based
on HiPPO that can get to 90% on Path-X?
Conclusion
In writing this post, we hope that fleshing out the details of these
models can lower the barrier to understanding S4 and inspire future
ideas in this area. The introduction of new diagonal state space
models shows the potential of SSMs as sequence models that can
be incredibly powerful, yet quite simple to understand and implement.
However, there’s much left to understand here, and we believe that
perhaps even simpler and better models may be uncovered!